# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that show that DistributionStrategy works with optimizer v2."""

import numpy as np
import tensorflow.compat.v2 as tf
from absl.testing import parameterized

import keras
from keras.optimizers.legacy import adam
from keras.optimizers.legacy import gradient_descent


def get_model():
    """Builds a tiny functional model: one Dense(4) layer on a 3-d input."""
    x = keras.layers.Input(shape=(3,), name="input")
    y = keras.layers.Dense(4, name="dense")(x)
    model = keras.Model(x, y)
    return model


class MirroredStrategyOptimizerV2Test(tf.test.TestCase, parameterized.TestCase):
    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.combine(
            distribution=[
                tf.__internal__.distribute.combinations.central_storage_strategy_with_two_gpus,  # noqa: E501
            ],
            mode=["graph", "eager"],
        )
    )
    def testKerasOptimizerWithUnequalInput(self, distribution):
        with distribution.scope():
            var = tf.Variable(
                2.0, name="var", aggregation=tf.VariableAggregation.SUM
            )
            # Small, equal betas keep the hand-computed Adam math below
            # short.
            optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2)
            all_vars = []

            def model_fn():
                def loss_fn():
                    # The per-replica loss is (replica_id + 1) * 0.5 * var,
                    # so its gradient w.r.t. var is 0.5 on replica 0 and
                    # 1.0 on replica 1.
                    replica_id = _replica_id()
                    return tf.cast(replica_id + 1, dtype=tf.float32) * 0.5 * var

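                # var uses VariableAggregation.SUM, so the optimizer applies
                # the sum of the replica gradients: 0.5 + 1.0 = 1.5.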
                train_op = optimizer.minimize(loss_fn, var_list=[var])

                return train_op, optimizer

            def train_fn():
                (
                    train_op,
                    optimizer,
                ) = distribution.extended.call_for_each_replica(model_fn)
                if not all_vars:
                    all_vars.append(var)
                    all_vars.append(optimizer.get_slot(var, "m"))
                    all_vars.append(optimizer.get_slot(var, "v"))
                return distribution.group(train_op)

            if not tf.executing_eagerly():
                with self.cached_session() as sess:
                    # In graph mode, build the train op once and wrap it in
                    # a callable so each train_fn() call reruns the same op.
                    train_fn = sess.make_callable(train_fn())
            self.evaluate(tf.compat.v1.global_variables_initializer())

            # First step.
            train_fn()
            # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2)
            #                 / (sqrt(v(1)) * (1 - beta1))
            #        = 2.0 - 0.01 * 1.2 * sqrt(0.8) / (sqrt(1.8) * 0.8)
            #        = 1.99
            self.assertAllClose(1.99, self.evaluate(all_vars[0]))
            # m(1) = beta1 * m(0) + (1 - beta1) * grad
            #      = 0.2 * 0 + 0.8 * (1 + 2) / 2 = 1.2
            self.assertAllClose(1.2, self.evaluate(all_vars[1]))
            # v(1) = beta2 * v(0) + (1 - beta2) * grad^2
            #      = 0.2 * 0 + 0.8 * 1.5^2 = 1.8
            self.assertAllClose(1.8, self.evaluate(all_vars[2]))

            # Second step. The bias-corrected ratio m_hat / sqrt(v_hat) is
            # still 1, so the update is again exactly lr.
            train_fn()
            # var(2) = var(1) - lr = 1.99 - 0.01 = 1.98
            self.assertAllClose(1.98, self.evaluate(all_vars[0]))
            # m(2) = beta1 * m(1) + (1 - beta1) * grad
            #      = 0.2 * 1.2 + 0.8 * 1.5 = 1.44
            self.assertAllClose(1.44, self.evaluate(all_vars[1]))
            # v(2) = beta2 * v(1) + (1 - beta2) * grad^2
            #      = 0.2 * 1.8 + 0.8 * 2.25 = 2.16
            self.assertAllClose(2.16, self.evaluate(all_vars[2]))
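
            # The constants asserted above can be reproduced with a plain
            # NumPy sketch of the Adam recurrence; a minimal re-derivation
            # assuming a negligible epsilon and the summed gradient of 1.5.
            lr, beta1, beta2 = 0.01, 0.2, 0.2
            grad = 1.5
            m = v = 0.0
            expected_var = 2.0
            for t in (1, 2):
                m = beta1 * m + (1 - beta1) * grad
                v = beta2 * v + (1 - beta2) * grad**2
                m_hat = m / (1 - beta1**t)  # bias-corrected first moment
                v_hat = v / (1 - beta2**t)  # bias-corrected second moment
                expected_var -= lr * m_hat / np.sqrt(v_hat)
            self.assertAllClose(expected_var, self.evaluate(all_vars[0]))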

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.combine(
            distribution=[
                tf.__internal__.distribute.combinations.central_storage_strategy_with_two_gpus,  # noqa: E501
            ],
            mode=["graph", "eager"],
        )
    )
    def testOptimizerWithKerasModelAndNumpyArrays(self, distribution):
        """Tests compile/fit/evaluate/predict with NumPy inputs."""
        with self.cached_session():
            with distribution.scope():
                model = get_model()
                optimizer = gradient_descent.SGD(0.001)
                loss = "mse"
                metrics = ["mae"]
                model.compile(optimizer, loss, metrics=metrics)

            inputs = np.zeros((64, 3), dtype=np.float32)
            targets = np.zeros((64, 4), dtype=np.float32)

            model.fit(
                inputs,
                targets,
                epochs=1,
                batch_size=2,
                verbose=0,
                validation_data=(inputs, targets),
            )
            model.evaluate(inputs, targets)
            model.predict(inputs)


def _replica_id():
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    # Some strategies report the replica id as a plain Python int; wrap it
    # in a constant so callers can treat it uniformly as a tensor.
    if not isinstance(replica_id, tf.Tensor):
        replica_id = tf.constant(replica_id)
    return replica_id


if __name__ == "__main__":
    tf.test.main()
